{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "除了使用 Sequential 实例来构造模型外；还能使用基于 Module 类来构建模型，让模型构造更加灵活"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 通过继承 Module 类来构造模型\n",
    "Module 类是nn模块里提供的一个模型构造类，是所有神经网络模块的基类，可以通过继承它来定义想要的模型。下面继承Module类构造多层感知机。这里定义的MLP类重载了Module类的__init__函数和forward函数。它们分别用于创建模型参数和定义前向计算。前向计算也即正向传播。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    # 声明带有模型参数的层，这里声明了两个全连接层\n",
    "    def __init__(self, **kwargs):\n",
    "        # 调用MLP父类Module的构造函数来进行必要的初始化。还可以在构造实例时指定其他函数参数\n",
    "        super(MLP, self).__init__(**kwargs)\n",
    "        self.hidden = nn.Linear(784, 256) # 隐藏层\n",
    "        self.act = nn.ReLU() # 激活函数\n",
    "        self.output = nn.Linear(256, 10) # 输出层\n",
    "    \n",
    "    # 定义模型的前向计算，即如何根据输入 X 计算返回所需要的模型输出\n",
    "    def forward(self, x):\n",
    "        a = self.act(self.hidden(x))\n",
    "        return self.output(a)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上的MLP类中无须定义反向传播函数。系统将通过自动求梯度而自动生成反向传播所需的backward函数。\n",
    "\n",
    "可以实例化MLP类得到模型变量net。并传入输入数据 X 做前向计算，net(X)会调用MLP继承自Module类的__call__函数，这个函数将调用MLP类定义的forward函数来完成前向计算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP(\n",
      "  (hidden): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (act): ReLU()\n",
      "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-0.2450, -0.1168,  0.0697,  0.1069,  0.0076, -0.1464,  0.0995,  0.2224,\n",
       "          0.0027, -0.1311],\n",
       "        [-0.2553, -0.0101,  0.0929,  0.0990, -0.0201, -0.1263,  0.1837,  0.2261,\n",
       "         -0.0491, -0.2419]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(2, 784)\n",
    "net = MLP()\n",
    "\n",
    "print(net)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里并没有将Module类命名为Layer（层）或者Model（模型）之类的名字，这是因为该类是一个可供自由组建的部件。它的子类既可以是一个层（如PyTorch提供的Linear类），又可以是一个模型（如这里定义的MLP类），或者是模型的一个部分"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Module的子类\n",
    "PyTorch还实现了继承自Module的可以方便构建模型的子类，例如\n",
    "### Sequential\n",
    "### ModuleList\n",
    "### ModuleDict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1、Sequential类\n",
    "当模型的前向计算为简单串联各个层的计算时，Sequential类可以通过更加简单的方式定义模型。这正是Sequential类的目的：它可以接收一个子模块的有序字典（OrderedDict）或者一系列子模块作为参数来逐一添加Module的实例，而模型的前向计算就是将这些实例按添加的顺序逐一计算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 为了更清晰了解Sequential类的工作机制，下面实现一个与Sequential类具有相同功能的类\n",
    "\n",
    "class MySequential(nn.Module):\n",
    "    from collections import OrderedDict\n",
    "    def __init__(self, *args):\n",
    "        super(MySequential, self).__init__()\n",
    "        # 如果传入的是一个 OrderedDict\n",
    "        if len(args) == 1 and isinstance(args[0], OrderedDict):\n",
    "            for key, module in args[0].items():\n",
    "                self.add_module(key, module) # add_module方法会将module添加进self._modules(一个OrderedDict)\n",
    "        # 如果传入的是一些Module\n",
    "        else: \n",
    "            for idx, module in enumerate(args):\n",
    "                self.add_module(str(idx), module)\n",
    "    \n",
    "    def forward(self, input):\n",
    "        # self._modules返回一个 OrderedDict，保证会按照成员添加时的顺序遍历成员\n",
    "        for module in self._modules.values():\n",
    "            input = module(input)\n",
    "        return input\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MySequential(\n",
      "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0141,  0.3379, -0.0142,  0.0492, -0.1428, -0.1336, -0.0488, -0.0529,\n",
       "         -0.1904,  0.0147],\n",
       "        [-0.1838,  0.2701, -0.0717, -0.0825, -0.2447, -0.2257, -0.1582, -0.2357,\n",
       "         -0.2115, -0.0237]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = MySequential(\n",
    "    nn.Linear(784, 256),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(256, 10),\n",
    ")\n",
    "print(net)\n",
    "# print(net._modules.values())\n",
    "x = torch.rand(2, 784)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2、ModuleList类\n",
    "ModuleLIst接受一个子模块的列表作为输入，然后也可以类似List那样进行append和extend操作："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=256, out_features=10, bias=True)\n",
      "ModuleList(\n",
      "  (0): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (1): ReLU()\n",
      "  (2): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])\n",
    "net.append(nn.Linear(256, 10)) # 类似List的append操作\n",
    "print(net[-1]) # 类似List的索引访问\n",
    "print(net)\n",
    "# net(torch.zeros(1, 784)) # 会报NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Sequential和ModuleList的区别\n",
    "ModuleList仅仅是一个储存各种模块的列表，这些模块之间没有联系也没有顺序（所以不用保证相邻层的输入输出维度匹配），而且没有实现forward功能需要自己实现，所以上面执行net(torch.zeros(1, 784))会报NotImplementedError；而Sequential内的模块需要按照顺序排列，要保证相邻层的输入输出大小相匹配，内部forward功能已经被实现。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ModuleList的出现只是让网络定义前向传播时更加灵活\n",
    "class MyModule(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MyModule, self).__init__()\n",
    "        self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(10)])\n",
    "    def forward(self, x):\n",
    "        # ModuleList can act as an iterable, or be indexed using ints\n",
    "        for i, module in enumerate(self.linears):\n",
    "            x = self.linears[i // 2](x) + module(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ModuleList不同于一般的Python的list，加入到ModuleList里面的所有模块的参数会被自动添加到整个网络中"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "net1:\n",
      "torch.Size([10, 10])\n",
      "torch.Size([10])\n",
      "net2:\n"
     ]
    }
   ],
   "source": [
    "# 对比实验\n",
    "class Module_ModuleList(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Module_ModuleList, self).__init__()\n",
    "        self.linears = nn.ModuleList([nn.Linear(10,10)])\n",
    "\n",
    "class Module_List(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Module_List, self).__init__()\n",
    "        self.linears = [nn.Linear(10,10)]\n",
    "\n",
    "net1 = Module_ModuleList()\n",
    "net2 = Module_List()\n",
    "print('net1:')\n",
    "for param in net1.parameters():\n",
    "    print(param.size())\n",
    "print('net2:')\n",
    "for param in net2.parameters():\n",
    "    print(param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3、ModuleDict类\n",
    "ModuleDict接受一个子模块的字典作为输入，然后也可以类似字典那样进行添加访问操作"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=784, out_features=256, bias=True)\n",
      "Linear(in_features=256, out_features=10, bias=True)\n",
      "ModuleDict(\n",
      "  (act): ReLU()\n",
      "  (linear): Linear(in_features=784, out_features=256, bias=True)\n",
      "  (output): Linear(in_features=256, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = nn.ModuleDict({\n",
    "    'linear': nn.Linear(784, 256),\n",
    "    'act': nn.ReLU(),\n",
    "})\n",
    "net['output'] = nn.Linear(256, 10) # 类似字典一样添加\n",
    "print(net['linear']) # 类似字典一样访问\n",
    "print(net.output)\n",
    "print(net)\n",
    "# net(torch.zeros(1, 784)) # 会报NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "和ModuleList一样，ModuleDict实例仅仅是存放了一些模块的字典，并没有定义forward函数，需要自己定义。同样，ModuleDict也与Python的Dict有所不同，ModuleDict里的所有模块的参数会被自动添加到整个网络中。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构造复杂的模型\n",
    "上面介绍的这些类可以使模型构造更加简单，且不需要定义forward函数，但直接继承Module类可以极大地拓展模型构造的灵活性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 更加复杂的模型\n",
    "# 可以通过get_constant函数创建训练中不被迭代的参数，即常数参数\n",
    "\n",
    "class FancyMLP(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(FancyMLP, self).__init__(**kwargs)\n",
    "        self.rand_weight = torch.rand((20,20), requires_grad=True) # 不可训练参数（常数参数）\n",
    "        self.linear = nn.Linear(20, 20)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.linear(x)\n",
    "        # 使用创建的常数参数，以及nn.functional中的relu函数和mm函数\n",
    "        x = nn.functional.relu(torch.mm(x, self.rand_weight.data) + 1)\n",
    "        \n",
    "        # 复用全连接层。等价于两个全连接层共享参数\n",
    "        x = self.linear(x)\n",
    "        while x.norm().item() > 1:\n",
    "            x /= 2\n",
    "        if x.norm().item() < 0.8:\n",
    "            x *= 10\n",
    "        return x.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FancyMLP(\n",
      "  (linear): Linear(in_features=20, out_features=20, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(14.4494, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(2, 20)\n",
    "net = FancyMLP()\n",
    "print(net)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为FancyMLP和Sequential类都是Module类的子类，所以可以嵌套调用它们"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): NestMLP(\n",
      "    (net): Sequential(\n",
      "      (0): Linear(in_features=40, out_features=30, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (1): Linear(in_features=30, out_features=20, bias=True)\n",
      "  (2): FancyMLP(\n",
      "    (linear): Linear(in_features=20, out_features=20, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor(-1.6943, grad_fn=<SumBackward0>)"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class NestMLP(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(NestMLP, self).__init__(**kwargs)\n",
    "        self.net = nn.Sequential(nn.Linear(40, 30), nn.ReLU())\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())\n",
    "x = torch.rand(2, 40)\n",
    "print(net)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "yczlab_3.6",
   "language": "python",
   "name": "yczlab_python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
